package parser

import (
	"bufio"
	"bytes"
	"coffelang/compiler/ast"
	"fmt"
	"io"
	"strconv"
	"strings"
)

// Sentinel value and whitespace bit masks used by the scanner.
const (
	// EOF is returned in place of a byte once the input is exhausted. It is
	// also stored in a position's Line to mark "end of input".
	EOF = -1

	// whitespace1 has one bit set per blank byte (tab, space); bit i stands
	// for the byte value i.
	whitespace1 = 1<<'\t' | 1<<' '

	// whitespace2 is whitespace1 extended with the line terminators.
	whitespace2 = whitespace1 | 1<<'\n' | 1<<'\r'
)

// Error implements the error interface. The message names the source and
// either the line/column and offending token, or "at EOF" when the recorded
// line equals EOF.
func (e *Error) Error() string {
	if at := e.Pos; at.Line != EOF {
		return fmt.Sprintf("%v line:%d(column:%d) near '%v':   %s\n", at.Source, at.Line, at.Column, e.Token, e.Message)
	}
	return fmt.Sprintf("%v at EOF:   %s\n", e.Pos.Source, e.Message)
}

// Error describes a scanning failure together with the place it refers to.
type Error struct {
	// Pos is the position reported by Error(); a Line equal to EOF is
	// rendered as "at EOF".
	Pos     ast.Position
	// Message is the description of the problem.
	Message string
	// Token is the text shown as "near '...'" in the formatted message.
	Token   string
}

// Scanner reads input one byte at a time and tracks the current position.
type Scanner struct {
	// Pos is the current source/line/column; Line is set to EOF once the end
	// of the input has been read.
	Pos    ast.Position
	// reader is the buffered byte source all reads go through.
	reader *bufio.Reader
}

// NewScanner is the Scanner constructor. Input is read from reader through a
// 4096-byte buffered reader; source is stored as the position's Source, and
// the position starts at line 1, column 0.
func NewScanner(reader io.Reader, source string) *Scanner {
	sc := new(Scanner)
	sc.reader = bufio.NewReaderSize(reader, 4096)
	sc.Pos.Source = source
	sc.Pos.Line = 1
	sc.Pos.Column = 0
	return sc
}

// Error builds an *Error for the given token text and message, stamped with
// the scanner's current position so the formatted error can report the source,
// line and column (or "at EOF") where scanning failed.
//
// Previously the position was left as the zero value, so every message showed
// an empty source with line 0 and the EOF form could never be produced.
func (sc *Scanner) Error(token string, message string) *Error {
	return &Error{
		Pos:     sc.Pos,
		Message: message,
		Token:   token,
	}
}

// readNext consumes one byte from the underlying reader and returns it as an
// int, or EOF when the reader reports io.EOF. Other read errors are not
// distinguished: the byte returned alongside them is passed through as is.
func (sc *Scanner) readNext() int {
	if b, err := sc.reader.ReadByte(); err != io.EOF {
		return int(b)
	}
	return EOF
}

// Peek returns the next value readNext would produce without consuming it.
// When that value is EOF nothing is pushed back.
func (sc *Scanner) Peek() int {
	upcoming := sc.readNext()
	if upcoming == EOF {
		return EOF
	}
	// Push the byte back so the following read sees it again.
	sc.reader.UnreadByte()
	return upcoming
}

// NewLine records a line break: the line counter advances and the column is
// reset. When ch is '\r' and the next byte is '\n', that '\n' is consumed too,
// so a CRLF pair counts as a single break. A negative ch is ignored.
func (sc *Scanner) NewLine(ch int) {
	if ch >= 0 {
		sc.Pos.Line++
		sc.Pos.Column = 0
		if upcoming := sc.Peek(); upcoming == '\n' && ch == '\r' {
			sc.reader.ReadByte()
		}
	}
}

// Next consumes one byte and updates the position. At EOF the position's line
// becomes EOF and the column 0. A '\n' or '\r' is handed to NewLine and always
// reported as '\n'. Any other byte advances the column by one.
func (sc *Scanner) Next() int {
	ch := sc.readNext()
	if ch == EOF {
		sc.Pos.Line = EOF
		sc.Pos.Column = 0
		return ch
	}
	if ch == '\n' || ch == '\r' {
		sc.NewLine(ch)
		return '\n'
	}
	sc.Pos.Column++
	return ch
}

// skipWhiteSpace consumes bytes while their bit is set in the whitespace mask
// and returns the first value whose bit is not set. Values outside 0..63
// (including EOF) shift to zero and therefore never count as whitespace.
func (sc *Scanner) skipWhiteSpace(whitespace int64) int {
	for {
		ch := sc.Next()
		if bit := int64(1) << uint(ch); whitespace&bit == 0 {
			return ch
		}
	}
}

// scanIdent writes ch to buf, followed by every upcoming byte for which
// isIdent(byte, 1) holds. It always returns nil.
func (sc *Scanner) scanIdent(ch int, buf *bytes.Buffer) error {
	writeChar(buf, ch)
	for {
		if !isIdent(sc.Peek(), 1) {
			return nil
		}
		writeChar(buf, sc.Next())
	}
}

// scanVar reads the braced part of a variable reference: a '{', a run of bytes
// accepted by isIdent(byte, 1), and a closing '}'. Only the bytes between the
// braces are written to buf; ch itself is not written. A missing brace yields
// an error carrying whatever has been collected so far.
func (sc *Scanner) scanVar(ch int, buf *bytes.Buffer) error {
	const illegal = "变量定义不合法"

	if sc.Peek() != '{' {
		return sc.Error(buf.String(), illegal)
	}
	sc.Next() // consume '{'

	for isIdent(sc.Peek(), 1) {
		writeChar(buf, sc.Next())
	}

	if sc.Peek() != '}' {
		return sc.Error(buf.String(), illegal)
	}
	sc.Next() // consume '}'
	return nil
}

// scanDecimal writes ch to buf and then consumes and writes upcoming bytes for
// as long as isDecimal accepts them. It always returns nil.
func (sc *Scanner) scanDecimal(ch int, buf *bytes.Buffer) error {
	writeChar(buf, ch)
	for {
		if !isDecimal(sc.Peek()) {
			break
		}
		writeChar(buf, sc.Next())
	}
	return nil
}

// scanNumber scans a numeric literal whose first byte ch has already been
// consumed, appending the literal's text to buf.
//
// Two shapes are handled:
//   - "0x"/"0X" followed by at least one byte accepted by isDigit; if none
//     follows, an "illegal hexadecimal number" error is returned.
//   - a run of decimal digits, optionally followed by '.' with more digits and
//     by an exponent ('e'/'E', an optional sign, then digits).
//
// Apart from the hexadecimal error the function always returns nil.
func (sc *Scanner) scanNumber(ch int, buf *bytes.Buffer) error {
	if ch == '0' { // leading zero: hexadecimal prefix, or a zero followed by more digits
		if sc.Peek() == 'x' || sc.Peek() == 'X' {
			writeChar(buf, ch)
			writeChar(buf, sc.Next())
			hasvalue := false
			// isDigit is defined outside this view; presumably it accepts
			// hexadecimal digits — verify.
			for isDigit(sc.Peek()) {
				writeChar(buf, sc.Next())
				hasvalue = true
			}
			if !hasvalue {
				return sc.Error(buf.String(), "illegal hexadecimal number")
			}
			return nil
		} else if sc.Peek() != '.' && isDecimal(sc.Peek()) {
			// Drop this leading zero: the following digit becomes the first
			// byte written by scanDecimal below.
			ch = sc.Next()
		}
	}
	// Integer part. scanDecimal always returns nil, so its result is unused.
	sc.scanDecimal(ch, buf)
	if sc.Peek() == '.' {
		// Next() yields the '.', which scanDecimal writes before the
		// fraction digits.
		sc.scanDecimal(sc.Next(), buf)
	}
	if ch = sc.Peek(); ch == 'e' || ch == 'E' {
		writeChar(buf, sc.Next())
		if ch = sc.Peek(); ch == '-' || ch == '+' {
			writeChar(buf, sc.Next())
		}
		// NOTE(review): the byte after the exponent marker/sign is consumed
		// and written without checking that it is a digit or that input
		// remains — confirm whether callers validate the text afterwards.
		sc.scanDecimal(sc.Next(), buf)
	}

	return nil
}

// scanString reads bytes up to the closing quote, appending the string's
// content to buf. The closing quote is consumed but not written. A backslash
// starts an escape sequence, which scanEscape decodes. A line break or end of
// input before the closing quote yields an "unterminated string" error.
func (sc *Scanner) scanString(quote int, buf *bytes.Buffer) error {
	for ch := sc.Next(); ch != quote; ch = sc.Next() {
		switch {
		case ch == '\n', ch == '\r', ch < 0:
			return sc.Error(buf.String(), "unterminated string")
		case ch == '\\':
			if err := sc.scanEscape(ch, buf); err != nil {
				return err
			}
		default:
			writeChar(buf, ch)
		}
	}
	return nil
}

// scanEscape decodes the escape sequence that follows a backslash and appends
// the result to buf. The incoming ch is ignored: the byte after the backslash
// is read from the scanner. A digit starts a decimal escape of up to three
// digits; any other unrecognised byte is handed to writeChar unchanged.
// It always returns nil.
func (sc *Scanner) scanEscape(ch int, buf *bytes.Buffer) error {
	ch = sc.Next()

	var out byte
	switch ch {
	case 'a':
		out = '\a'
	case 'b':
		out = '\b'
	case 'f':
		out = '\f'
	case 'n':
		out = '\n'
	case 'r':
		out = '\r'
	case 't':
		out = '\t'
	case 'v':
		out = '\v'
	case '\\', '"', '\'', '\n':
		// These escapes stand for the escaped byte itself.
		out = byte(ch)
	case '\r':
		buf.WriteByte('\n')
		sc.NewLine('\r')
		return nil
	default:
		if ch < '0' || ch > '9' {
			writeChar(buf, ch)
			return nil
		}
		// Decimal escape: the first digit plus at most two more.
		digits := []byte{byte(ch)}
		for n := 0; n < 2 && isDecimal(sc.Peek()); n++ {
			digits = append(digits, byte(sc.Next()))
		}
		code, _ := strconv.ParseInt(string(digits), 10, 32)
		writeChar(buf, int(code))
		return nil
	}

	buf.WriteByte(out)
	return nil
}

// countSep counts a run of '=' beginning with ch, consuming one byte per '='
// seen. It returns the count and the first value that is not '='.
func (sc *Scanner) countSep(ch int) (int, int) {
	n := 0
	for ch == '=' {
		ch = sc.Next()
		n++
	}
	return n, ch
}

// scanMultilineString reads a long-bracket string. ch is the first byte after
// the initial '[': a run of '=' fixes the bracket level and must be followed by
// a second '['. Content is copied to buf until a ']' followed by the same
// number of '=' and another ']' is found. One line break directly after the
// opening bracket is skipped.
//
// It returns an "invalid multiline string" error when the opening bracket is
// malformed and an "unterminated multiline string" error when the input ends
// before the closing bracket.
func (sc *Scanner) scanMultilineString(ch int, buf *bytes.Buffer) error {
	var openLevel int
	openLevel, ch = sc.countSep(ch)
	if ch != '[' {
		// string(rune(ch)) gives the same text as the former string(ch) but
		// is not rejected by vet's int-to-string conversion check.
		return sc.Error(string(rune(ch)), "invalid multiline string")
	}

	ch = sc.Next()
	if ch == '\n' || ch == '\r' {
		ch = sc.Next()
	}

	for {
		switch {
		case ch < 0:
			return sc.Error(buf.String(), "unterminated multiline string")
		case ch == ']':
			var closeLevel int
			closeLevel, ch = sc.countSep(sc.Next())
			if closeLevel == openLevel && ch == ']' {
				return nil
			}
			// Not the closing bracket: keep what was consumed as content and
			// re-examine ch on the next iteration.
			buf.WriteByte(']')
			buf.WriteString(strings.Repeat("=", closeLevel))
		default:
			writeChar(buf, ch)
			ch = sc.Next()
		}
	}
}

// reservedWords maps reserved identifiers to their token types. It is
// currently empty: every entry below is commented out, so a lookup never
// matches.
var reservedWords = map[string]int{
	//"and": TAnd, "break": TBreak, "do": TDo, "else": TElse, "elseif": TElseIf,
	//"end": TEnd, "false": TFalse, "for": TFor, "function": TFunction,
	//"if": TIf, "in": TIn, "local": TLocal, "nil": TNil, "not": TNot, "or": TOr,
	//"return": TReturn, "repeat": TRepeat, "then": TThen, "true": TTrue,
	//"until": TUntil, "while": TWhile
}

// Scan reads the next token from the input.
//
// Blanks (tab, space) before the token are skipped; if a line break follows
// them, all further whitespace including line breaks is skipped as well.
// lexer.PNewLine is set to true only when the token starts with '(', the type
// in lexer.PreTokenType is ')' and a line break was skipped on the way;
// otherwise it is reset to false.
//
// The token's position is the scanner position right after the token's first
// byte has been read. An unrecognised byte yields the token together with an
// "Invalid token" error.
func (sc *Scanner) Scan(lexer *Lexer) (ast.Token, error) {
	var err error
	tok := ast.Token{}
	newline := false

	// Skip whitespace; ch is the first byte of the token.
	ch := sc.skipWhiteSpace(whitespace1)
	if ch == '\n' || ch == '\r' {
		newline = true
		ch = sc.skipWhiteSpace(whitespace2)
	}

	lexer.PNewLine = ch == '(' && lexer.PreTokenType == ')' && newline

	buf := new(bytes.Buffer)
	tok.Pos = sc.Pos

	switch {
	//case isIdent(ch, 0):
	//// identifier
	//tok.Type = Token_Type_Ident
	//err = sc.scanIdent(ch, buf)
	//tok.Str = buf.String()
	//if err != nil {
	//goto finally
	//}
	//if typ, ok := reservedWords[tok.Str]; ok {
	//tok.Type = typ
	//}
	case isDecimal(ch):
		// Decimal number.
		tok.Type = Token_Type_Number
		err = sc.scanNumber(ch, buf)
		tok.Str = buf.String()
	default:
		switch ch {
		case EOF:
			tok.Type = EOF
		case '>':
			if sc.Peek() == '=' {
				tok.Type = Token_Type_Assign
				tok.Str = ">="
				sc.Next()
			} else {
				tok.Type = ch
				tok.Str = string(rune(ch))
			}
		// NOTE(review): '-' and '[' are not in this list although '+' and ']'
		// are — confirm whether that is intended.
		case '+', '*', '/', '%', '^', '#', '(', ')', '{', '}', ']', ';', ':', ',':
			// Single-byte tokens use the byte value as their type.
			// string(rune(ch)) gives the same text as the former string(ch)
			// but is not rejected by vet's int-to-string conversion check.
			tok.Type = ch
			tok.Str = string(rune(ch))
		default:
			writeChar(buf, ch)
			err = sc.Error(buf.String(), "Invalid token")
			goto finally
		}
	}

finally:
	tok.Name = TokenName(int(tok.Type))
	return tok, err
}

// TokenName returns a one-byte string holding the low 8 bits of c.
func TokenName(c int) string {
	var name [1]byte
	name[0] = byte(c)
	return string(name[:])
}

// ScanT reads the next token, recognising delimiters, decimal numbers,
// variables, '>' (typed as Token_Type_Assign) and the signs '+' and '-'.
//
// Only blanks (tab, space) are skipped before the token. The token's position
// is the scanner position right after the token's first byte has been read.
// An unrecognised byte yields the token together with an "Invalid token"
// error; an error from scanning a number or variable is returned alongside
// the partially filled token.
func (sc *Scanner) ScanT() (ast.Token, error) {
	var err error
	tok := ast.Token{}

	// Skip blanks and read the first byte of the token.
	ch := sc.skipWhiteSpace(whitespace1)

	buf := new(bytes.Buffer)
	tok.Pos = sc.Pos

	// string(rune(ch)) below gives the same text as the former string(ch) but
	// is not rejected by vet's int-to-string conversion check.
	switch {
	case isDelimiter(ch):
		// Delimiter.
		tok.Type = Token_Type_Delimiter
		tok.Str = string(rune(ch))
	case isDecimal(ch):
		// Decimal number.
		tok.Type = Token_Type_Number
		err = sc.scanNumber(ch, buf)
		tok.Str = buf.String()
	case isVar(ch, 0):
		// Variable.
		tok.Type = Token_Type_Var
		err = sc.scanVar(ch, buf)
		tok.Str = buf.String()
	default:
		switch ch {
		case EOF:
			tok.Type = EOF
		case '>':
			tok.Type = Token_Type_Assign
			tok.Str = string(rune(ch))
		case '+', '-':
			// Sign tokens use the byte value as their type.
			tok.Type = ch
			tok.Str = string(rune(ch))
		default:
			writeChar(buf, ch)
			err = sc.Error(buf.String(), "Invalid token")
		}
	}

	tok.Name = TokenName(int(tok.Type))
	return tok, err
}
